#pragma once
#include <string.h>  // for memset
#include <assert.h>

// Largest representable deadline (the value timer_start saturates to when
// `now + timeout` overflows). The whole expansion is parenthesized so it acts
// as a single operand in any expression context, e.g. `sizeof TIMER_MAX_TIMEOUT`.
#define TIMER_MAX_TIMEOUT (~0ULL)
// Swap two lvalues of type `_type`. Each argument is evaluated more than once,
// so pass plain lvalues without side effects. The temporary has a
// macro-specific name so that caller variables called e.g. `tmp` are not
// captured (shadowed) by it.
#define TIMER_HEAP_SWAP(_type, _a, _b)          \
    do {                                        \
        _type timer_heap_swap_tmp_ = (_a);      \
        (_a) = (_b);                            \
        (_b) = timer_heap_swap_tmp_;            \
    } while (0)
typedef struct htimer_s htimer_t;
typedef struct htimer_mgr_s htimer_mgr_t;
// Expiry callback. timer_tick calls it after removing the timer from the heap
// and re-arming it (if it repeats).
typedef void (*timer_cb)(htimer_t* timer);
typedef void (*timer_debug_dump_cb)(htimer_t* timer);  // only for debug
// Heap ordering predicate: nonzero when lhs belongs above rhs in the heap.
typedef int (*timer_heap_compare)(htimer_t* lhs, htimer_t* rhs);
// A timer. The node is intrusive: its heap links live in the timer itself, so
// starting and stopping a timer never allocates.
struct htimer_s
{
    // children and parent in the manager's binary heap (NULL when absent)
    htimer_t *left, *right, *parent;
    // manager whose heap holds this timer; timer_stop sets it back to NULL
    htimer_mgr_t* timer_mgr;
    // absolute expiry time, in the manager's clock units (see timer_start)
    unsigned long long timeout;
    // re-arm interval used by timer_again; 0 means the timer fires once
    unsigned long long repeat;
    // expiry callback invoked by timer_tick; timer_start rejects NULL
    timer_cb cb;
};

// Timer manager: the heap of active timers plus the current clock value.
struct htimer_mgr_s
{
    // top of the heap: the minimum timer according to less_than (NULL when empty)
    htimer_t* root;
    // heap ordering predicate; timer_mgr_init installs timer_heap_less_than
    timer_heap_compare less_than;
    // current time, as last set by timer_mgr_init or timer_tick; timer_start
    // computes absolute deadlines relative to it
    unsigned long long timeout;
    // number of timers currently in the heap
    unsigned long long size;
};

// Default heap ordering: the timer with the earlier absolute deadline sorts
// first. No tie-breaker is applied, so equal deadlines compare as "not less".
static inline int timer_heap_less_than(htimer_t* lhs, htimer_t* rhs) {
    const unsigned long long left_deadline = lhs->timeout;
    const unsigned long long right_deadline = rhs->timeout;
    return (left_deadline < right_deadline) ? 1 : 0;
}

// Debug helper: pass the heap's top timer (the root) to `dump` when the heap
// is non-empty. Only the root is passed; walking the rest of the tree through
// its left/right links is left to the callback.
static inline void timer_mgr_dump(
    htimer_mgr_t* timer_mgr, timer_debug_dump_cb dump) {
    htimer_t* const top = timer_mgr->root;
    if (top == NULL) {
        return;
    }
    dump(top);
}

// Prepare `timer_mgr` for use: an empty heap (root NULL, size 0), the clock
// set to `now_time`, and timer_heap_less_than installed as the comparator.
// A designated-initializer compound literal is used instead of memset so the
// pointer members are genuinely NULL on every platform (all-bits-zero is not
// guaranteed to be a null pointer); unnamed members are zero-initialized.
static inline void timer_mgr_init(
    htimer_mgr_t* timer_mgr, unsigned long long now_time) {
    *timer_mgr = (htimer_mgr_t){
        .timeout = now_time,
        .less_than = timer_heap_less_than,
    };
}

// Exchange the tree positions of `parent` and its direct child `child` by
// relinking pointers only; the timers' own fields (timeout, cb, ...) stay put.
// Used for sift-up in timer_start/timer_stop and sift-down in timer_stop.
// Updates timer_mgr->root when `parent` was the root.
// NOTE(review): assumes child->parent == parent on entry -- the
// `child->left == child` test below is only meaningful in that case.
static inline void timer_heap_swap(
    htimer_mgr_t* timer_mgr, htimer_t* parent, htimer_t* child) {
    // exchange the left/right child pointers; then child takes over parent's
    // parent link (the grandparent) and parent is hung below child
    TIMER_HEAP_SWAP(htimer_t*, child->left, parent->left);
    TIMER_HEAP_SWAP(htimer_t*, child->right, parent->right);
    child->parent = parent->parent;
    parent->parent = child;

    // child now holds parent's old child pointers, one of which points at
    // child itself: redirect that slot to parent, and re-parent the other one
    // (child's former sibling) to child
    htimer_t* sibling = NULL;
    if (child->left == child) {
        child->left = parent;
        sibling = child->right;
    }
    else {
        child->right = parent;
        sibling = child->left;
    }
    if (sibling) {
        sibling->parent = child;
    }

    // parent now holds child's former children; fix their back-pointers
    if (parent->left) {
        parent->left->parent = parent;
    }
    if (parent->right) {
        parent->right->parent = parent;
    }

    // hook child into the slot parent used to occupy: the root, or the
    // grandparent's left/right pointer (which still refers to parent here)
    if (child->parent == NULL) {
        timer_mgr->root = child;
    }
    else if (child->parent->left == parent) {
        child->parent->left = child;
    }
    else {
        child->parent->right = child;
    }
}

// Put `node` into the tree position currently held by `replaced`, then fully
// detach `replaced` (its links and timer_mgr are cleared). Does not change
// timer_mgr->size and does not restore heap order; timer_stop handles both.
// NOTE(review): `node` must have a parent (it is dereferenced
// unconditionally) and no children of its own (its left/right are
// overwritten) -- timer_stop passes the heap's last node, which fits both.
static inline void timer_heap_replace(
    htimer_mgr_t* timer_mgr, htimer_t* replaced, htimer_t* node) {
    // detach node from its parent first: if that parent is `replaced`, the
    // slot is cleared before it is copied below, so node cannot end up as its
    // own child
    if (node->parent->left == node) {
        node->parent->left = NULL;
    }
    else {
        node->parent->right = NULL;
    }

    // take over replaced's children and parent
    node->left = replaced->left;
    node->right = replaced->right;
    node->parent = replaced->parent;
    if (node->left) {
        node->left->parent = node;
    }
    if (node->right) {
        node->right->parent = node;
    }

    // point replaced's parent slot (or the root) at node
    if (!replaced->parent) {
        timer_mgr->root = node;
    }
    else if (replaced->parent->left == replaced) {
        replaced->parent->left = node;
    }
    else {
        replaced->parent->right = node;
    }
    replaced->timer_mgr = NULL;
    replaced->left = replaced->right = replaced->parent = NULL;
}

// Arm `timer` on `timer_mgr` and insert it into the manager's heap.
//   timeout: delay before the first expiry, added to the manager's current
//            time (timer_mgr->timeout); if the sum wraps around, the deadline
//            saturates at the largest unsigned long long value
//   repeat:  interval timer_again re-arms with after each expiry (0 = one-shot)
// Returns -1 without touching the timer when cb is NULL, otherwise 0.
// NOTE(review): there is no check that `timer` is already active; starting an
// active timer links it into the heap a second time and corrupts the heap.
// Stop the timer first, as timer_again does.
static inline int timer_start(htimer_mgr_t* timer_mgr, htimer_t* timer,
    timer_cb cb, unsigned long long timeout, unsigned long long repeat) {
    if (cb == NULL) {
        return -1;
    }

    unsigned long long deadline = timer_mgr->timeout + timeout;
    if (deadline < timeout) {
        deadline = ~0ULL;  // unsigned wrap-around: clamp instead
    }

    timer->left = NULL;
    timer->right = NULL;
    timer->parent = NULL;
    timer->timer_mgr = timer_mgr;
    timer->repeat = repeat;
    timer->timeout = deadline;
    timer->cb = cb;

    // The new node takes slot (size + 1) in 1-based level order. The bits of
    // that slot number below its leading 1, read from the top down, are the
    // steps from the root (0 = left, 1 = right); gather them reversed so the
    // first step ends up in bit 0, and count them to get the depth.
    unsigned long long route = 0;
    unsigned long long depth = 0;
    for (unsigned long long slot = timer_mgr->size + 1; slot > 1; slot /= 2) {
        route = (route << 1) | (slot & 1);
        depth++;
    }

    // Follow the route down to the empty slot, remembering the node above it.
    // With no steps (empty heap) the root slot itself is used.
    htimer_t* above = timer_mgr->root;
    htimer_t** hole = &timer_mgr->root;
    for (; depth > 0; depth--, route >>= 1) {
        above = *hole;
        hole = (route & 1) ? &above->right : &above->left;
    }
    timer->parent = above;
    *hole = timer;
    timer_mgr->size++;

    // Sift up while the new timer orders before its parent.
    while (timer->parent != NULL && timer_mgr->less_than(timer, timer->parent)) {
        timer_heap_swap(timer_mgr, timer->parent, timer);
    }
    return 0;
}

// Remove `timer` from the heap of the manager it was started on. A timer whose
// timer_mgr is NULL (already stopped, or fired by timer_tick without repeat)
// is ignored, as is a manager with size 0. Always returns 0.
// NOTE(review): timer->timer_mgr is read and dereferenced unconditionally, so
// it must be NULL or valid -- a timer never passed to timer_start may hold
// garbage here.
static inline int timer_stop(htimer_t* timer) {
    htimer_mgr_t* timer_mgr = timer->timer_mgr;
    if (!timer_mgr || !timer_mgr->size) {
        return 0;
    }
    // The last node is slot `size` in 1-based level order. The bits of `size`
    // below its leading 1 are the left(0)/right(1) steps from the root,
    // collected here in reverse so bit 0 is the first step; k is the depth.
    unsigned long long path = 0, n = 0, k = 0;
    for (k = 0, n = timer_mgr->size; n > 1; k += 1, n /= 2) {
        path = (path << 1) | (n & 1);
    }
    // find the last node (the right-most occupied slot of the bottom row)
    htimer_t* node = timer_mgr->root;
    while (k) {
        node = (path & 1) ? node->right : node->left;
        path >>= 1;
        k -= 1;
    }
    timer_mgr->size -= 1;

    if (node == timer) {
        // removing the last node itself (the root if it was the only node);
        // as the last slot of the complete tree it is a leaf, so unlinking it
        // from its parent is enough
        if (timer == timer_mgr->root) {
            timer_mgr->root = NULL;
        }
        else {
            assert(timer->parent);
            if (timer->parent->left == timer) {
                timer->parent->left = NULL;
            }
            else {
                timer->parent->right = NULL;
            }
            timer->parent = NULL;
        }
        timer->timer_mgr = NULL;
        return 0;
    }

    // move the last node into timer's slot; timer is fully detached
    timer_heap_replace(timer_mgr, timer, node);

    // The moved node may order after its new children or before its new
    // parent, so heap order is restored in both directions.
    // walk down heapify
    htimer_t* smallest = NULL;
    for (;;) {
        smallest = node;
        if (node->left && timer_mgr->less_than(node->left, smallest)) {
            smallest = node->left;
        }
        if (node->right && timer_mgr->less_than(node->right, smallest)) {
            smallest = node->right;
        }
        if (smallest == node) {
            break;
        }
        timer_heap_swap(timer_mgr, node, smallest);
    }
    // walk up heapify
    while (node->parent && timer_mgr->less_than(node, node->parent)) {
        timer_heap_swap(timer_mgr, node->parent, node);
    }
    return 0;
}

// Re-arm a repeating timer: stop it, then start it again on `timer_mgr` with a
// delay of `repeat` from the manager's current time (and the same repeat).
// Timers with repeat == 0 are left untouched.
// Returns -1 if the timer has no callback, otherwise 0.
static inline int timer_again(htimer_mgr_t* timer_mgr, htimer_t* timer) {
    if (timer->cb == NULL) {
        return -1;
    }
    if (timer->repeat == 0) {
        return 0;  // one-shot timer: nothing to re-arm
    }
    const unsigned long long interval = timer->repeat;
    timer_stop(timer);
    (void)timer_start(timer_mgr, timer, timer->cb, interval, interval);
    return 0;
}

// Advance the manager's clock to `now_time` and fire every timer whose
// deadline is <= the clock, taking them from the top of the heap one at a
// time. Each expired timer is removed, re-armed via timer_again if it repeats,
// and only then has its callback run. Always returns 0.
// NOTE(review): the heap top is re-read after every callback, so a callback
// that starts a timer with timeout 0 creates one that is already due and
// fires within this same call; doing that unconditionally keeps this loop
// from ever returning.
static inline int timer_tick(
    htimer_mgr_t* timer_mgr, unsigned long long now_time) {
    timer_mgr->timeout = now_time;
    htimer_t* due = NULL;
    while ((due = timer_mgr->root) != NULL &&
           due->timeout <= timer_mgr->timeout) {
        timer_stop(due);
        timer_again(timer_mgr, due);
        due->cb(due);
    }
    return 0;
}